Skip to content

Commit

Permalink
Merge pull request BeastByteAI#110 from iryna-kondr/gguf
Browse files Browse the repository at this point in the history
replaced gpt4all with llama-cpp-python
  • Loading branch information
iryna-kondr authored Aug 4, 2024
2 parents a9c29c4 + 10c09d2 commit 261f501
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 70 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
]
name = "scikit-llm"
version = "1.3.1"
version = "1.4.0"
authors = [
{ name="Oleh Kostromin", email="[email protected]" },
{ name="Iryna Kondrashchenko", email="[email protected]" },
Expand All @@ -27,7 +27,7 @@ classifiers = [
]

[project.optional-dependencies]
gpt4all = ["gpt4all>=2.0.0,<3.0.0"]
gguf = ["llama-cpp-python>=0.2.82,<0.2.83"]
annoy = ["annoy>=1.17.2,<2.0.0"]

[tool.ruff]
Expand Down
2 changes: 1 addition & 1 deletion skllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.3.1'
__version__ = '1.4.0'
__author__ = 'Iryna Kondrashchenko, Oleh Kostromin'
37 changes: 36 additions & 1 deletion skllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
_GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
_GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
_GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
_GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"


class SKLLMConfig:
Expand Down Expand Up @@ -169,4 +172,36 @@ def get_gpt_url() -> Optional[str]:
@staticmethod
def reset_gpt_url():
"""Resets the GPT URL."""
os.environ.pop(_GPT_URL_VAR, None)
os.environ.pop(_GPT_URL_VAR, None)

@staticmethod
def get_gguf_download_path() -> str:
"""Gets the path to store the downloaded GGUF files."""
default_path = os.path.join(os.path.expanduser("~"), ".skllm", "gguf")
return os.environ.get(_GGUF_DOWNLOAD_PATH, default_path)

@staticmethod
def get_gguf_max_gpu_layers() -> int:
"""Gets the maximum number of layers to use for the GGUF model."""
return int(os.environ.get(_GGUF_MAX_GPU_LAYERS, 0))

@staticmethod
def set_gguf_max_gpu_layers(n_layers: int):
"""Sets the maximum number of layers to use for the GGUF model."""
if not isinstance(n_layers, int):
raise ValueError("n_layers must be an integer")
if n_layers < -1:
n_layers = -1
os.environ[_GGUF_MAX_GPU_LAYERS] = str(n_layers)

@staticmethod
def set_gguf_verbose(verbose: bool):
"""Sets the verbosity of the GGUF model."""
if not isinstance(verbose, bool):
raise ValueError("verbose must be a boolean")
os.environ[_GGUF_VERBOSE] = str(verbose)

@staticmethod
def get_gguf_verbose() -> bool:
"""Gets the verbosity of the GGUF model."""
return os.environ.get(_GGUF_VERBOSE, "False").lower() == "true"
54 changes: 0 additions & 54 deletions skllm/llm/gpt/clients/gpt4all/completion.py

This file was deleted.

12 changes: 12 additions & 0 deletions skllm/llm/gpt/clients/llama_cpp/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from skllm.llm.gpt.clients.llama_cpp.handler import ModelCache, LlamaHandler


def get_chat_completion(messages: dict, model: str, **kwargs):

with ModelCache.lock:
handler = ModelCache.get(model)
if handler is None:
handler = LlamaHandler(model)
ModelCache.store(model, handler)

return handler.get_chat_completion(messages, **kwargs)
164 changes: 164 additions & 0 deletions skllm/llm/gpt/clients/llama_cpp/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import threading
import os
import hashlib
import requests
from tqdm import tqdm
import hashlib
from typing import Optional
import tempfile
from skllm.config import SKLLMConfig
from warnings import warn


try:
from llama_cpp import Llama as _Llama

_llama_imported = True
except (ImportError, ModuleNotFoundError):
_llama_imported = False


supported_models = {
"llama3-8b-q4": {
"download_url": "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",
"sha256": "c57380038ea85d8bec586ec2af9c91abc2f2b332d41d6cf180581d7bdffb93c1",
"n_ctx": 8192,
"supports_system_message": True,
},
"gemma2-9b-q4": {
"download_url": "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf",
"sha256": "13b2a7b4115bbd0900162edcebe476da1ba1fc24e718e8b40d32f6e300f56dfe",
"n_ctx": 8192,
"supports_system_message": False,
},
"phi3-mini-q4": {
"download_url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf",
"sha256": "8a83c7fb9049a9b2e92266fa7ad04933bb53aa1e85136b7b30f1b8000ff2edef",
"n_ctx": 4096,
"supports_system_message": False,
},
"mistral0.3-7b-q4": {
"download_url": "https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
"sha256": "1270d22c0fbb3d092fb725d4d96c457b7b687a5f5a715abe1e818da303e562b6",
"n_ctx": 32768,
"supports_system_message": False,
},
"gemma2-2b-q6": {
"download_url": "https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q6_K_L.gguf",
"sha256": "b2ef9f67b38c6e246e593cdb9739e34043d84549755a1057d402563a78ff2254",
"n_ctx": 8192,
"supports_system_message": False,
},
}


class LlamaHandler:

def maybe_download_model(self, model_name, download_url, sha256) -> str:
download_folder = SKLLMConfig.get_gguf_download_path()
os.makedirs(download_folder, exist_ok=True)
model_name = model_name + ".gguf"
model_path = os.path.join(download_folder, model_name)
if not os.path.exists(model_path):
print("The model `{0}` is not found locally.".format(model_name))
self._download_model(model_name, download_folder, download_url, sha256)
return model_path

def _download_model(
self, model_filename: str, model_path: str, url: str, expected_sha256: str
) -> str:
full_path = os.path.join(model_path, model_filename)
temp_file = tempfile.NamedTemporaryFile(delete=False, dir=model_path)
temp_path = temp_file.name
temp_file.close()

response = requests.get(url, stream=True)

if response.status_code != 200:
os.remove(temp_path)
raise ValueError(
f"Request failed: HTTP {response.status_code} {response.reason}"
)

total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 * 1024 * 4

sha256 = hashlib.sha256()

with (
open(temp_path, "wb") as file,
tqdm(
desc="Downloading {0}: ".format(model_filename),
total=total_size_in_bytes,
unit="iB",
unit_scale=True,
) as progress_bar,
):
for data in response.iter_content(block_size):
file.write(data)
sha256.update(data)
progress_bar.update(len(data))

downloaded_sha256 = sha256.hexdigest()
if downloaded_sha256 != expected_sha256:
raise ValueError(
f"Expected SHA-256 hash {expected_sha256}, but got {downloaded_sha256}"
)

os.rename(temp_path, full_path)

def __init__(self, model: str):
if not _llama_imported:
raise ImportError(
"llama_cpp is not installed, try `pip install scikit-llm[llama_cpp]`"
)
self.lock = threading.Lock()
if model not in supported_models:
raise ValueError(f"Model {model} is not supported.")
download_url = supported_models[model]["download_url"]
sha256 = supported_models[model]["sha256"]
n_ctx = supported_models[model]["n_ctx"]
self.supports_system_message = supported_models[model][
"supports_system_message"
]
if not self.supports_system_message:
warn(
f"The model {model} does not support system messages. This may cause issues with some estimators."
)
extended_model_name = model + "-" + sha256[:8]
model_path = self.maybe_download_model(
extended_model_name, download_url, sha256
)
max_gpu_layers = SKLLMConfig.get_gguf_max_gpu_layers()
verbose = SKLLMConfig.get_gguf_verbose()
self.model = _Llama(
model_path=model_path,
n_ctx=n_ctx,
verbose=verbose,
n_gpu_layers=max_gpu_layers,
)

def get_chat_completion(self, messages: dict, **kwargs):
if not self.supports_system_message:
messages = [m for m in messages if m["role"] != "system"]
with self.lock:
return self.model.create_chat_completion(
messages, temperature=0.0, **kwargs
)


class ModelCache:
lock = threading.Lock()
cache: dict[str, LlamaHandler] = {}

@classmethod
def get(cls, key) -> Optional[LlamaHandler]:
return cls.cache.get(key, None)

@classmethod
def store(cls, key, value):
cls.cache[key] = value

@classmethod
def clear(cls):
cls.cache = {}
17 changes: 11 additions & 6 deletions skllm/llm/gpt/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from skllm.llm.gpt.clients.openai.completion import (
get_chat_completion as _oai_get_chat_completion,
)
from skllm.llm.gpt.clients.gpt4all.completion import (
get_chat_completion as _g4a_get_chat_completion,
from skllm.llm.gpt.clients.llama_cpp.completion import (
get_chat_completion as _llamacpp_get_chat_completion,
)
from skllm.llm.gpt.utils import split_to_api_and_model
from skllm.config import SKLLMConfig as _Config


def get_chat_completion(
messages: dict,
openai_key: str = None,
Expand All @@ -17,14 +18,18 @@ def get_chat_completion(
):
"""Gets a chat completion from the OpenAI compatible API."""
api, model = split_to_api_and_model(model)
if api == "gpt4all":
return _g4a_get_chat_completion(messages, model)
if api == "gguf":
return _llamacpp_get_chat_completion(messages, model)
else:
url = _Config.get_gpt_url()
if api == "openai" and url is not None:
warnings.warn(f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`.")
warnings.warn(
f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`."
)
elif api == "custom_url" and url is None:
raise ValueError("You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url(<url>)`.")
raise ValueError(
"You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url(<url>)`."
)
return _oai_get_chat_completion(
messages,
openai_key,
Expand Down
6 changes: 6 additions & 0 deletions skllm/llm/gpt/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def _get_openai_key(self) -> str:
key = self.key
if key is None:
key = _Config.get_openai_key()
if (
hasattr(self, "model")
and isinstance(self.model, str)
and self.model.startswith("gguf::")
):
key = "gguf"
if key is None:
raise RuntimeError("OpenAI key was not found")
return key
Expand Down
4 changes: 2 additions & 2 deletions skllm/llm/gpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple

SUPPORTED_APIS = ["openai", "azure", "gpt4all", "custom_url"]
SUPPORTED_APIS = ["openai", "azure", "gguf", "custom_url"]


def split_to_api_and_model(model: str) -> Tuple[str, str]:
Expand All @@ -9,4 +9,4 @@ def split_to_api_and_model(model: str) -> Tuple[str, str]:
for api in SUPPORTED_APIS:
if model.startswith(f"{api}::"):
return api, model[len(api) + 2 :]
raise ValueError(f"Unsupported API: {model.split('::')[0]}")
raise ValueError(f"Unsupported API: {model.split('::')[0]}")
12 changes: 8 additions & 4 deletions skllm/models/_base/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,14 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
warnings.warn(
"Passing num_workers to predict is temporary and will be removed in the future."
)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
predictions = list(
tqdm(executor.map(self._predict_single, X), total=len(X))
)
with ThreadPoolExecutor(max_workers=num_workers) as executor:
predictions = list(
tqdm(executor.map(self._predict_single, X), total=len(X))
)
else:
predictions = []
for x in tqdm(X):
predictions.append(self._predict_single(x))

return np.array(predictions)

Expand Down

0 comments on commit 261f501

Please sign in to comment.