Skip to content

Commit

Permalink
centralized definitions in constants.py
Browse files Browse the repository at this point in the history
Using load_model directly from run_localGPT.py and MODEL_ID & MODEL_BASENAME from constants.py
  • Loading branch information
PromtEngineer authored Aug 5, 2023
1 parent 1cb7b41 commit e8c2fb5
Showing 1 changed file with 2 additions and 25 deletions.
27 changes: 2 additions & 25 deletions run_localGPT_API.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from werkzeug.utils import secure_filename

from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY
from constants import CHROMA_SETTINGS, EMBEDDING_MODEL_NAME, PERSIST_DIRECTORY, MODEL_ID, MODEL_BASENAME

DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
SHOW_SOURCES = True
Expand Down Expand Up @@ -64,30 +64,7 @@

RETRIEVER = DB.as_retriever()

# for HF models
# model_id = "TheBloke/vicuna-7B-1.1-HF"
# model_id = "TheBloke/Wizard-Vicuna-7B-Uncensored-HF"
# model_id = "TheBloke/guanaco-7B-HF"
# model_id = 'NousResearch/Nous-Hermes-13b' # Requires ~ 23GB VRAM.
# Using STransformers alongside will 100% create OOM on 24GB cards.
# LLM = load_model(device_type=DEVICE_TYPE, model_id=model_id)

# for GPTQ (quantized) models
# model_id = "TheBloke/Nous-Hermes-13B-GPTQ"
# model_basename = "nous-hermes-13b-GPTQ-4bit-128g.no-act.order"
# model_id = "TheBloke/WizardLM-30B-Uncensored-GPTQ"
# model_basename = "WizardLM-30B-Uncensored-GPTQ-4bit.act-order.safetensors"
# Requires ~21GB VRAM. Using STransformers alongside can potentially create OOM on 24GB cards.
# model_id = "TheBloke/wizardLM-7B-GPTQ"
# model_basename = "wizardLM-7B-GPTQ-4bit.compat.no-act-order.safetensors"

# model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
# model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"

model_id = "TheBloke/Llama-2-7B-Chat-GGML"
model_basename = "llama-2-7b-chat.ggmlv3.q4_0.bin"

LLM = load_model(device_type=DEVICE_TYPE, model_id=model_id, model_basename=model_basename)
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)

QA = RetrievalQA.from_chain_type(
llm=LLM, chain_type="stuff", retriever=RETRIEVER, return_source_documents=SHOW_SOURCES
Expand Down

0 comments on commit e8c2fb5

Please sign in to comment.