Commit

ggml quant cpu mps support
Jeff committed Jun 28, 2023
1 parent 925d63c commit 60e7ee1
Showing 2 changed files with 46 additions and 25 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
 # Natural Language Processing
 langchain==0.0.191
 chromadb==0.3.22
-llama-cpp-python==0.1.48
+llama-cpp-python==0.1.66
 pdfminer.six==20221105
 InstructorEmbedding
 sentence-transformers
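The dependency bump is what enables the new code path: the GGML checkpoints referenced below are loaded through llama-cpp-python, so the pin moves from 0.1.48 to 0.1.66. A quick way to confirm the installed release matches the pin (an illustrative check, not part of the commit):

    from importlib.metadata import version  # stdlib, Python 3.8+

    # Confirm the llama-cpp-python release pinned in requirements.txt is installed.
    assert version("llama-cpp-python") == "0.1.66"
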
69 changes: 45 additions & 24 deletions run_localGPT.py
@@ -3,9 +3,10 @@
 import click
 import torch
 from auto_gptq import AutoGPTQForCausalLM
+from huggingface_hub import hf_hub_download
 from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.llms import HuggingFacePipeline
+from langchain.llms import HuggingFacePipeline, LlamaCpp

 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
@@ -39,33 +40,45 @@ def load_model(device_type, model_id, model_basename=None):
     Raises:
         ValueError: If an unsupported model or device type is provided.
     """
-    if device_type.lower() in ["cpu", "mps"]:
-        model_basename = None
-
     logging.info(f"Loading Model: {model_id}, on: {device_type}")
     logging.info("This action can take a few minutes!")

     if model_basename is not None:
-        # The code supports all huggingface models that ends with GPTQ and have some variation
-        # of .no-act.order or .safetensors in their HF repo.
-        logging.info("Using AutoGPTQForCausalLM for quantized models")
-
-        if ".safetensors" in model_basename:
-            # Remove the ".safetensors" ending if present
-            model_basename = model_basename.replace(".safetensors", "")
-
-        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-        logging.info("Tokenizer loaded")
-
-        model = AutoGPTQForCausalLM.from_quantized(
-            model_id,
-            model_basename=model_basename,
-            use_safetensors=True,
-            trust_remote_code=True,
-            device="cuda:0",
-            use_triton=False,
-            quantize_config=None,
-        )
+        if device_type.lower() in ["cpu", "mps"]:
+            logging.info("Using Llamacpp for quantized models")
+            model_path = hf_hub_download(repo_id=model_id, filename=model_basename)
+            if device_type.lower() == "mps":
+                return LlamaCpp(
+                    model_path=model_path,
+                    n_ctx=2048,
+                    max_tokens=2048,
+                    temperature=0,
+                    repeat_penalty=1.15,
+                    n_gpu_layers=1000,
+                )
+            return LlamaCpp(model_path=model_path, n_ctx=2048, max_tokens=2048, temperature=0, repeat_penalty=1.15)
+
+        else:
+            # The code supports all huggingface models that ends with GPTQ and have some variation
+            # of .no-act.order or .safetensors in their HF repo.
+            logging.info("Using AutoGPTQForCausalLM for quantized models")
+
+            if ".safetensors" in model_basename:
+                # Remove the ".safetensors" ending if present
+                model_basename = model_basename.replace(".safetensors", "")
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+            logging.info("Tokenizer loaded")
+
+            model = AutoGPTQForCausalLM.from_quantized(
+                model_id,
+                model_basename=model_basename,
+                use_safetensors=True,
+                trust_remote_code=True,
+                device="cuda:0",
+                use_triton=False,
+                quantize_config=None,
+            )
     elif (
         device_type.lower() == "cuda"
     ):  # The code supports all huggingface models that ends with -HF or which have a .bin
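
Boiled down, the new CPU/MPS branch does two things: fetch the GGML weights file from the Hugging Face Hub, then point LlamaCpp at the local path. A minimal standalone sketch of that flow, reusing an example checkpoint from this commit (the prompt string is illustrative):

    from huggingface_hub import hf_hub_download
    from langchain.llms import LlamaCpp

    # Download the quantized GGML file; hf_hub_download returns a local cache path.
    model_path = hf_hub_download(
        repo_id="TheBloke/orca_mini_3B-GGML",
        filename="orca-mini-3b.ggmlv3.q4_0.bin",
    )

    llm = LlamaCpp(
        model_path=model_path,
        n_ctx=2048,          # context window, matching the commit's settings
        max_tokens=2048,
        temperature=0,
        repeat_penalty=1.15,
        n_gpu_layers=1000,   # MPS case only: offload (effectively all) layers to Metal
    )
    print(llm("What is GGML quantization?"))

Passing a very large n_gpu_layers is a common shorthand for "offload every layer"; the plain CPU return above simply omits the argument.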
@@ -198,6 +211,14 @@ def main(device_type, show_sources):
     # model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
     # model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"

+    # for GGML (quantized cpu+gpu+mps) models - check if they support llama.cpp
+    # model_id = "TheBloke/wizard-vicuna-13B-GGML"
+    # model_basename = "wizard-vicuna-13B.ggmlv3.q4_0.bin"
+    # model_basename = "wizard-vicuna-13B.ggmlv3.q6_K.bin"
+    # model_basename = "wizard-vicuna-13B.ggmlv3.q2_K.bin"
+    # model_id = "TheBloke/orca_mini_3B-GGML"
+    # model_basename = "orca-mini-3b.ggmlv3.q4_0.bin"
+
     llm = load_model(device_type, model_id=model_id, model_basename=model_basename)

     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
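
Whichever backend load_model returns, the chain is queried the same way downstream. A short usage sketch (the query text is illustrative):

    res = qa("What does the indexed document say about embeddings?")
    answer = res["result"]
    docs = res["source_documents"]  # present because return_source_documents=True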