diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ebdb1ae4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,169 @@ +# Ignore vscode +/.vscode +/DB +/models + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +#MacOS +.DS_Store +SOURCE_DOCUMENTS/.DS_Store \ No newline at end of file diff --git a/constants.py b/constants.py index a8e233f5..9e91b502 100644 --- a/constants.py +++ b/constants.py @@ -23,6 +23,9 @@ is_persistent=True, ) +# Context Window and Max New Tokens +CONTEXT_WINDOW_SIZE = 4096 +MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE#int(CONTEXT_WINDOW_SIZE/4) # https://python.langchain.com/en/latest/_modules/langchain/document_loaders/excel.html#UnstructuredExcelLoader DOCUMENT_MAP = { @@ -70,13 +73,28 @@ #### 32b 130 GB 65 GB 32.5 GB - 35 GB 16.25 GB - 19 GB #### 65b 260.8 GB 130.4 GB 65.2 GB - 67 GB 32.6 GB - - 35 GB -MODEL_ID = "TheBloke/Llama-2-7B-Chat-GGML" -MODEL_BASENAME = "llama-2-7b-chat.ggmlv3.q4_0.bin" +# MODEL_ID = "TheBloke/Llama-2-7B-Chat-GGML" +# MODEL_BASENAME = "llama-2-7b-chat.ggmlv3.q4_0.bin" + +#### +#### (FOR GGUF MODELS) +#### + +# MODEL_ID = "TheBloke/Llama-2-13b-Chat-GGUF" +# MODEL_BASENAME = "llama-2-13b-chat.Q4_K_M.gguf" + +MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF" +MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf" + +# MODEL_ID = "TheBloke/Llama-2-70b-Chat-GGUF" +# MODEL_BASENAME = "llama-2-70b-chat.Q4_K_M.gguf" #### #### (FOR HF MODELS) #### +# MODEL_ID = "NousResearch/Llama-2-7b-chat-hf" +# MODEL_BASENAME = None # MODEL_ID = "TheBloke/vicuna-7B-1.1-HF" # MODEL_BASENAME = None # MODEL_ID = "TheBloke/Wizard-Vicuna-7B-Uncensored-HF" @@ -92,43 +110,43 @@ ##### 48GB VRAM Graphics Cards (RTX 6000, RTX A6000 and other 48GB VRAM GPUs) ##### ### 65b GPTQ LLM Models for 48GB GPUs (*** With best embedding model: hkunlp/instructor-xl ***) -# model_id = "TheBloke/guanaco-65B-GPTQ" -# model_basename = "model.safetensors" -# model_id = "TheBloke/Airoboros-65B-GPT4-2.0-GPTQ" -# model_basename = "model.safetensors" -# model_id = "TheBloke/gpt4-alpaca-lora_mlp-65B-GPTQ" -# model_basename = "model.safetensors" -# model_id = "TheBloke/Upstage-Llama1-65B-Instruct-GPTQ" -# model_basename = "model.safetensors" +# MODEL_ID = "TheBloke/guanaco-65B-GPTQ" +# MODEL_BASENAME = "model.safetensors" +# MODEL_ID = "TheBloke/Airoboros-65B-GPT4-2.0-GPTQ" +# MODEL_BASENAME = "model.safetensors" +# MODEL_ID = "TheBloke/gpt4-alpaca-lora_mlp-65B-GPTQ" +# MODEL_BASENAME = "model.safetensors" +# MODEL_ID = "TheBloke/Upstage-Llama1-65B-Instruct-GPTQ" +# MODEL_BASENAME = "model.safetensors" ##### 24GB VRAM Graphics Cards (RTX 3090 - RTX 4090 (35% Faster) - RTX A5000 - RTX A5500) ##### ### 13b GPTQ Models for 24GB GPUs (*** With best embedding model: hkunlp/instructor-xl ***) -# model_id = "TheBloke/Wizard-Vicuna-13B-Uncensored-GPTQ" -# model_basename = "Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors" -# model_id = "TheBloke/vicuna-13B-v1.5-GPTQ" -# model_basename = "model.safetensors" -# model_id = "TheBloke/Nous-Hermes-13B-GPTQ" -# model_basename = "nous-hermes-13b-GPTQ-4bit-128g.no-act.order" -# model_id = "TheBloke/WizardLM-13B-V1.2-GPTQ" -# model_basename = "gptq_model-4bit-128g.safetensors +# MODEL_ID = "TheBloke/Wizard-Vicuna-13B-Uncensored-GPTQ" +# MODEL_BASENAME = "Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors" +# MODEL_ID = "TheBloke/vicuna-13B-v1.5-GPTQ" +# MODEL_BASENAME = "model.safetensors" +# MODEL_ID = "TheBloke/Nous-Hermes-13B-GPTQ" +# MODEL_BASENAME = "nous-hermes-13b-GPTQ-4bit-128g.no-act.order" +# MODEL_ID = "TheBloke/WizardLM-13B-V1.2-GPTQ" +# MODEL_BASENAME = "gptq_model-4bit-128g.safetensors ### 30b GPTQ Models for 24GB GPUs (*** Requires using intfloat/e5-base-v2 instead of hkunlp/instructor-large as embedding model ***) 
-# model_id = "TheBloke/Wizard-Vicuna-30B-Uncensored-GPTQ" -# model_basename = "Wizard-Vicuna-30B-Uncensored-GPTQ-4bit--1g.act.order.safetensors" -# model_id = "TheBloke/WizardLM-30B-Uncensored-GPTQ" -# model_basename = "WizardLM-30B-Uncensored-GPTQ-4bit.act-order.safetensors" +# MODEL_ID = "TheBloke/Wizard-Vicuna-30B-Uncensored-GPTQ" +# MODEL_BASENAME = "Wizard-Vicuna-30B-Uncensored-GPTQ-4bit--1g.act.order.safetensors" +# MODEL_ID = "TheBloke/WizardLM-30B-Uncensored-GPTQ" +# MODEL_BASENAME = "WizardLM-30B-Uncensored-GPTQ-4bit.act-order.safetensors" ##### 8-10GB VRAM Graphics Cards (RTX 3080 - RTX 3080 Ti - RTX 3070 Ti - 3060 Ti - RTX 2000 Series, Quadro RTX 4000, 5000, 6000) ##### ### (*** Requires using intfloat/e5-small-v2 instead of hkunlp/instructor-large as embedding model ***) ### 7b GPTQ Models for 8GB GPUs -# model_id = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ" -# model_basename = "Wizard-Vicuna-7B-Uncensored-GPTQ-4bit-128g.no-act.order.safetensors" -# model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ" -# model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors" -# model_id = "TheBloke/wizardLM-7B-GPTQ" -# model_basename = "wizardLM-7B-GPTQ-4bit.compat.no-act-order.safetensors" +# MODEL_ID = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ" +# MODEL_BASENAME = "Wizard-Vicuna-7B-Uncensored-GPTQ-4bit-128g.no-act.order.safetensors" +# MODEL_ID = "TheBloke/WizardLM-7B-uncensored-GPTQ" +# MODEL_BASENAME = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors" +# MODEL_ID = "TheBloke/wizardLM-7B-GPTQ" +# MODEL_BASENAME = "wizardLM-7B-GPTQ-4bit.compat.no-act-order.safetensors" #### #### (FOR GGML) (Quantized cpu+gpu+mps) models - check if they support llama.cpp diff --git a/load_models.py b/load_models.py new file mode 100644 index 00000000..95cf3256 --- /dev/null +++ b/load_models.py @@ -0,0 +1,159 @@ + +import torch +from auto_gptq import AutoGPTQForCausalLM +from huggingface_hub import hf_hub_download +from langchain.llms import LlamaCpp + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + LlamaForCausalLM, + LlamaTokenizer, +) +from constants import ( + CONTEXT_WINDOW_SIZE, + MAX_NEW_TOKENS +) + +def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging): + """ + Load a GGUF/GGML quantized model using LlamaCpp. + + This function attempts to load a GGUF/GGML quantized model using the LlamaCpp library. + If the model is of type GGML, and newer version of LLAMA-CPP is used which does not support GGML, + it logs a message indicating that LLAMA-CPP has dropped support for GGML. + + Parameters: + - model_id (str): The identifier for the model on HuggingFace Hub. + - model_basename (str): The base name of the model file. + - device_type (str): The type of device where the model will run, e.g., 'mps', 'cuda', etc. + - logging (logging.Logger): Logger instance for logging messages. + + Returns: + - LlamaCpp: An instance of the LlamaCpp model if successful, otherwise None. + + Notes: + - The function uses the `hf_hub_download` function to download the model from the HuggingFace Hub. + - The number of GPU layers is set based on the device type. 
+ """ + + try: + logging.info("Using Llamacpp for GGUF/GGML quantized models") + model_path = hf_hub_download( + repo_id=model_id, + filename=model_basename, + resume_download=True, + cache_dir="./models", + ) + kwargs = { + "model_path": model_path, + "n_ctx": CONTEXT_WINDOW_SIZE, + "max_tokens": MAX_NEW_TOKENS, + } + if device_type.lower() == "mps": + kwargs["n_gpu_layers"] = 1 + if device_type.lower() == "cuda": + kwargs["n_gpu_layers"] = 100 # set this based on your GPU + + return LlamaCpp(**kwargs) + except: + if 'ggml' in model_basename: + logging.INFO("If you were using GGML model, LLAMA-CPP Dropped Support, Use GGUF Instead") + return None + +def load_quantized_model_qptq(model_id, model_basename, device_type, logging): + + """ + Load a GPTQ quantized model using AutoGPTQForCausalLM. + + This function loads a quantized model that ends with GPTQ and may have variations + of .no-act.order or .safetensors in their HuggingFace repo. + + Parameters: + - model_id (str): The identifier for the model on HuggingFace Hub. + - model_basename (str): The base name of the model file. + - device_type (str): The type of device where the model will run. + - logging (logging.Logger): Logger instance for logging messages. + + Returns: + - model (AutoGPTQForCausalLM): The loaded quantized model. + - tokenizer (AutoTokenizer): The tokenizer associated with the model. + + Notes: + - The function checks for the ".safetensors" ending in the model_basename and removes it if present. + """ + + # The code supports all huggingface models that ends with GPTQ and have some variation + # of .no-act.order or .safetensors in their HF repo. + logging.info("Using AutoGPTQForCausalLM for quantized models") + + if ".safetensors" in model_basename: + # Remove the ".safetensors" ending if present + model_basename = model_basename.replace(".safetensors", "") + + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) + logging.info("Tokenizer loaded") + + model = AutoGPTQForCausalLM.from_quantized( + model_id, + model_basename=model_basename, + use_safetensors=True, + trust_remote_code=True, + device_map="auto", + use_triton=False, + quantize_config=None, + ) + return model, tokenizer + +def load_full_model(model_id, model_basename, device_type, logging): + + """ + Load a full model using either LlamaTokenizer or AutoModelForCausalLM. + + This function loads a full model based on the specified device type. + If the device type is 'mps' or 'cpu', it uses LlamaTokenizer and LlamaForCausalLM. + Otherwise, it uses AutoModelForCausalLM. + + Parameters: + - model_id (str): The identifier for the model on HuggingFace Hub. + - model_basename (str): The base name of the model file. + - device_type (str): The type of device where the model will run. + - logging (logging.Logger): Logger instance for logging messages. + + Returns: + - model (Union[LlamaForCausalLM, AutoModelForCausalLM]): The loaded model. + - tokenizer (Union[LlamaTokenizer, AutoTokenizer]): The tokenizer associated with the model. + + Notes: + - The function uses the `from_pretrained` method to load both the model and the tokenizer. + - Additional settings are provided for NVIDIA GPUs, such as loading in 4-bit and setting the compute dtype. 
+ """ + + if device_type.lower() in ["mps", "cpu"]: + logging.info("Using LlamaTokenizer") + tokenizer = LlamaTokenizer.from_pretrained( + model_id, + cache_dir="./models/" + ) + model = LlamaForCausalLM.from_pretrained( + model_id, + cache_dir="./models/" + ) + else: + logging.info("Using AutoModelForCausalLM for full models") + tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/") + logging.info("Tokenizer loaded") + model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + cache_dir="./models/", + # trust_remote_code=True, # set these if you are using NVIDIA GPU + # load_in_4bit=True, + # bnb_4bit_quant_type="nf4", + # bnb_4bit_compute_dtype=torch.float16, + # max_memory={0: "15GB"} # Uncomment this line with you encounter CUDA out of memory errors + ) + model.tie_weights() + return model, tokenizer \ No newline at end of file diff --git a/prompt_template_utils.py b/prompt_template_utils.py new file mode 100644 index 00000000..72e5aa1f --- /dev/null +++ b/prompt_template_utils.py @@ -0,0 +1,57 @@ +''' +This file implements prompt template for llama based models. +Modify the prompt template based on the model you select. +This seems to have significant impact on the output of the LLM. +''' + +from langchain.memory import ConversationBufferMemory +from langchain.prompts import PromptTemplate + +# this is specific to Llama-2. + +system_prompt = """You are a helpful assistant, you will use the provided context to answer user questions. +Read the given context before answering questions and think step by step. If you can not answer a user question based on +the provided context, inform the user. Do not use any other information for answering user""" + + +def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False): + + if promptTemplate_type=="llama": + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS + if history: + instruction = """ + Context: {history} \n {context} + User: {question}""" + + prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST + prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template) + else: + instruction = """ + Context: {context} + User: {question}""" + + prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST + prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template) + + else: + # change this based on the model you have selected. 
+ if history: + prompt_template = system_prompt + """ + + Context: {history} \n {context} + User: {question} + Answer:""" + prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template) + else: + prompt_template = system_prompt + """ + + Context: {context} + User: {question} + Answer:""" + prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template) + + memory = ConversationBufferMemory(input_key="question", memory_key="history") + + return prompt, memory, diff --git a/run_localGPT_v2.py b/run_localGPT_v2.py new file mode 100644 index 00000000..ba94905c --- /dev/null +++ b/run_localGPT_v2.py @@ -0,0 +1,251 @@ +import os +import logging +import click +import torch +from langchain.chains import RetrievalQA +from langchain.embeddings import HuggingFaceInstructEmbeddings +from langchain.llms import HuggingFacePipeline +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler # for streaming response +from langchain.callbacks.manager import CallbackManager +callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) + +from prompt_template_utils import get_prompt_template + +# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.vectorstores import Chroma +from transformers import (GenerationConfig, + pipeline, + ) + +from load_models import (load_quantized_model_gguf_ggml, + load_quantized_model_qptq, + load_full_model, + ) + +from constants import (EMBEDDING_MODEL_NAME, + PERSIST_DIRECTORY, + MODEL_ID, + MODEL_BASENAME, + MAX_NEW_TOKENS + ) + +def load_model(device_type, model_id, model_basename=None, LOGGING=logging): + """ + Select a model for text generation using the HuggingFace library. + If you are running this for the first time, it will download a model for you. + subsequent runs will use the model from the disk. + + Args: + device_type (str): Type of device to use, e.g., "cuda" for GPU or "cpu" for CPU. + model_id (str): Identifier of the model to load from HuggingFace's model hub. + model_basename (str, optional): Basename of the model if using quantized models. + Defaults to None. + + Returns: + HuggingFacePipeline: A pipeline object for text generation using the loaded model. + + Raises: + ValueError: If an unsupported model or device type is provided. 
+ """ + logging.info(f"Loading Model: {model_id}, on: {device_type}") + logging.info("This action can take a few minutes!") + + if model_basename is not None: + if ".gguf" in model_basename: + llm = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING) + return llm + elif ".ggml" in model_basename.lower(): + model, tokenizer = load_quantized_model_gguf_ggml(model_id, model_basename, device_type, LOGGING) + # return llm + else: + model, tokenizer = load_quantized_model_qptq(model_id, model_basename, device_type, LOGGING) + else: + model, tokenizer = load_full_model(model_id, model_basename, device_type, LOGGING) + + # Load configuration from the model to avoid warnings + generation_config = GenerationConfig.from_pretrained(model_id) + # see here for details: + # https://huggingface.co/docs/transformers/ + # main_classes/text_generation#transformers.GenerationConfig.from_pretrained.returns + + # Create a pipeline for text generation + pipe = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + max_length=MAX_NEW_TOKENS, + temperature=0.2, + # top_p=0.95, + repetition_penalty=1.15, + generation_config=generation_config, + ) + + local_llm = HuggingFacePipeline(pipeline=pipe) + logging.info("Local LLM Loaded") + + return local_llm + + +def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"): + """ + Initializes and returns a retrieval-based Question Answering (QA) pipeline. + + This function sets up a QA system that retrieves relevant information using embeddings + from the HuggingFace library. It then answers questions based on the retrieved information. + + Parameters: + - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'cuda', etc. + - use_history (bool): Flag to determine whether to use chat history or not. + + Returns: + - RetrievalQA: An initialized retrieval-based QA system. + + Notes: + - The function uses embeddings from the HuggingFace library, either instruction-based or regular. + - The Chroma class is used to load a vector store containing pre-computed embeddings. + - The retriever fetches relevant documents or data based on a query. + - The prompt and memory, obtained from the `get_prompt_template` function, might be used in the QA system. + - The model is loaded onto the specified device using its ID and basename. + - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents. + """ + + embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, + model_kwargs={"device": device_type}) + # uncomment the following line if you used HuggingFaceEmbeddings in the ingest.py + # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME) + + # load the vectorstore + db = Chroma(persist_directory=PERSIST_DIRECTORY, + embedding_function=embeddings,) + retriever = db.as_retriever() + + # get the prompt template and memory if set by the user. + prompt, memory = get_prompt_template(promptTemplate_type=promptTemplate_type, + history=use_history) + + # load the llm pipeline + llm = load_model(device_type, + model_id=MODEL_ID, + model_basename=MODEL_BASENAME, + LOGGING=logging) + + if use_history: + qa = RetrievalQA.from_chain_type(llm=llm, + chain_type="stuff", # try other chains types as well. 
refine, map_reduce, map_rerank + retriever=retriever, + return_source_documents=True,  # verbose=True, + callbacks=callback_manager, + chain_type_kwargs={"prompt": prompt, "memory": memory},) + else: + qa = RetrievalQA.from_chain_type(llm=llm, + chain_type="stuff", # try other chain types as well: refine, map_reduce, map_rerank + retriever=retriever, + return_source_documents=True,  # verbose=True, + callbacks=callback_manager, + chain_type_kwargs={"prompt": prompt,},) + + return qa + + +# choose the device type to run on, as well as whether to show source documents. +@click.command() +@click.option( + "--device_type", + default="cuda" if torch.cuda.is_available() else "cpu", + type=click.Choice( + [ + "cpu", + "cuda", + "ipu", + "xpu", + "mkldnn", + "opengl", + "opencl", + "ideep", + "hip", + "ve", + "fpga", + "ort", + "xla", + "lazy", + "vulkan", + "mps", + "meta", + "hpu", + "mtia", + ], + ), + help="Device to run on. (Default is cuda)", +) +@click.option( + "--show_sources", + "-s", + is_flag=True, + help="Show sources along with answers (Default is False)", +) +@click.option( + "--use_history", + "-h", + is_flag=True, + help="Use history (Default is False)", +) +def main(device_type, show_sources, use_history): + """ + Implements the main information retrieval task for a localGPT. + + This function sets up the QA system by loading the necessary embeddings, vectorstore, and LLM model. + It then enters an interactive loop where the user can input queries and receive answers. Optionally, + the source documents used to derive the answers can also be displayed. + + Parameters: + - device_type (str): Specifies the type of device where the model will run, e.g., 'cpu', 'mps', 'cuda', etc. + - show_sources (bool): Flag to determine whether to display the source documents used for answering. + - use_history (bool): Flag to determine whether to use chat history or not. + + Notes: + - Logging information includes the device type, whether source documents are displayed, and the use of history. + - If the models directory does not exist, it creates a new one to store models. + - The user can exit the interactive loop by entering "exit". + - The source documents are displayed if the show_sources flag is set to True. + + """ + + logging.info(f"Running on: {device_type}") + logging.info(f"Display Source Documents set to: {show_sources}") + logging.info(f"Use history set to: {use_history}") + + # if the models directory does not exist, create it to store the downloaded models. + if not os.path.exists("./models"): + os.mkdir("models") + + qa = retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama") + # Interactive questions and answers + while True: + + query = input("\nEnter a query: ") + if query == "exit": + break + # Get the answer from the chain + res = qa(query) + answer, docs = res["result"], res["source_documents"] + + # Print the result + print("\n\n> Question:") + print(query) + print("\n> Answer:") + print(answer) + + if show_sources: # this flag controls whether the source documents are displayed along with the answer. + # Print the relevant sources used for the answer + print("----------------------------------SOURCE DOCUMENTS---------------------------") + for document in docs: + print("\n> " + document.metadata["source"] + ":") + print(document.page_content) + print("----------------------------------SOURCE DOCUMENTS---------------------------") + + +if __name__ == "__main__": + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s", level=logging.INFO + ) + main()
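
A minimal usage sketch of the new loader and prompt helpers on the GGUF path, assuming the default MODEL_ID / MODEL_BASENAME constants above and a llama-cpp-python build with GGUF support; run_localGPT_v2.py wires the same pieces through RetrievalQA, while this sketch calls the LLM directly (the context string and question are illustrative only):

import logging

from constants import MODEL_ID, MODEL_BASENAME
from load_models import load_quantized_model_gguf_ggml
from prompt_template_utils import get_prompt_template

logging.basicConfig(level=logging.INFO)

# Download the GGUF file into ./models (or reuse a cached copy) and load it via llama.cpp.
llm = load_quantized_model_gguf_ggml(MODEL_ID, MODEL_BASENAME, "cuda", logging)

# history=False returns a Llama-2 style prompt with only {context} and {question} variables.
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=False)

text = prompt.format(
    context="LocalGPT lets you chat with your own documents locally.",
    question="What does LocalGPT do?",
)
print(llm(text))

The llm object here is the same kind of object retrieval_qa_pipline passes to RetrievalQA.from_chain_type, so this sketch only sanity-checks the model download and the prompt template in isolation, without the vector store.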