Skip to content

Commit

Permalink
Add registry functions to instantiate models by provider (#428)
Browse files Browse the repository at this point in the history
* Add provider-specific registry functions.

* Update model registry handles used in tests.

* Update readme and usage examples.

* Update spacy_llm/models/rest/openai/registry.py

Co-authored-by: Sofie Van Landeghem <[email protected]>

* Fix HF registry return type.

* Fix GPU test error message regexes.

* Fix tests. Bump default OAI model to GPT-4.

* Fix external tests.

* Format.

* Ignore LangChain deprecation warning. Ease sentiment tests.

* Use GPT-4 for sharding spancat test case.

* Relax EL test. Remove unnecessary warning contexts.

* Fix comparison in EL test.

* Fix GPU tests.

---------

Co-authored-by: Sofie Van Landeghem <[email protected]>
  • Loading branch information
rmitsch and svlandeg authored Apr 24, 2024
1 parent c87d5a6 commit ff07682
Show file tree
Hide file tree
Showing 46 changed files with 363 additions and 97 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ factory = "llm"
labels = ["COMPLIMENT", "INSULT"]

[components.llm.model]
@llm_models = "spacy.GPT-4.v2"
@llm_models = "spacy.OpenAI.v1"
name = "gpt-4"
```

Now run:
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ filterwarnings = [
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
"ignore:^.*was deprecated in langchain-community.*"
"ignore:^.*was deprecated in langchain-community.*",
"ignore:^.*was deprecated in LangChain 0.0.1.*",
"ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ langchain>=0.1,<0.2; python_version>="3.9"
openai>=0.27,<=0.28.1; python_version>="3.9"

# Necessary for running all local models on GPU.
transformers[sentencepiece]>=4.0.0
# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
transformers[sentencepiece]>=4.0.0,<=4.38
torch
einops>=0.4

Expand Down
2 changes: 2 additions & 0 deletions spacy_llm/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from .llama2 import llama2_hf
from .mistral import mistral_hf
from .openllama import openllama_hf
from .registry import huggingface_v1
from .stablelm import stablelm_hf

__all__ = [
"HuggingFace",
"dolly_hf",
"falcon_hf",
"huggingface_v1",
"llama2_hf",
"mistral_hf",
"openllama_hf",
Expand Down
3 changes: 1 addition & 2 deletions spacy_llm/models/hf/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ def mistral_hf(
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
the raw responses.
RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
"""
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
Expand Down
51 changes: 51 additions & 0 deletions spacy_llm/models/hf/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional

from confection import SimpleFrozenDict

from ...registry import registry
from .base import HuggingFace
from .dolly import Dolly
from .falcon import Falcon
from .llama2 import Llama2
from .mistral import Mistral
from .openllama import OpenLLaMA
from .stablelm import StableLM


@registry.llm_models("spacy.HF.v1")
@registry.llm_models("spacy.HuggingFace.v1")
def huggingface_v1(
name: str,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> HuggingFace:
"""Returns HuggingFace model instance.
name (str): Name of model to use.
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Model instance that can execute a set of prompts and return
the raw responses.
"""
model_context_lengths = {
Dolly: 2048,
Falcon: 2048,
Llama2: 4096,
Mistral: 8000,
OpenLLaMA: 2048,
StableLM: 4096,
}

for model_cls, context_length in model_context_lengths.items():
model_names = getattr(model_cls, "MODEL_NAMES")
if model_names and name in model_names.__args__:
return model_cls(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

raise ValueError(
f"Name {name} could not be associated with any of the supported models. Please check "
f"https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
)
1 change: 1 addition & 0 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def query_langchain(
prompts (Iterable[Iterable[Any]]): Prompts to execute.
RETURNS (Iterable[Iterable[Any]]): LLM responses.
"""
assert callable(model)
return [
[model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
]
Expand Down
37 changes: 37 additions & 0 deletions spacy_llm/models/rest/anthropic/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,43 @@
from .model import Anthropic, Endpoints


@registry.llm_models("spacy.Anthropic.v1")
def anthropic_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Anthropic.DEFAULT_STRICT,
max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
interval: float = Anthropic.DEFAULT_INTERVAL,
max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Anthropic:
"""Returns Anthropic model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (Anthropic): Instance of Anthropic model.
"""
return Anthropic(
name=name,
endpoint=Endpoints.COMPLETIONS.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Claude-2.v2")
def anthropic_claude_2_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
Expand Down
39 changes: 38 additions & 1 deletion spacy_llm/models/rest/cohere/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,43 @@
from .model import Cohere, Endpoints


@registry.llm_models("spacy.Cohere.v1")
def cohere_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Cohere.DEFAULT_STRICT,
max_tries: int = Cohere.DEFAULT_MAX_TRIES,
interval: float = Cohere.DEFAULT_INTERVAL,
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Cohere:
"""Returns Cohere model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (Cohere): Instance of Cohere model.
"""
return Cohere(
name=name,
endpoint=Endpoints.COMPLETION.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Command.v2")
def cohere_command_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
Expand Down Expand Up @@ -56,7 +93,7 @@ def cohere_command(
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Cohere instance for 'command' model using REST to prompt API.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
Expand Down
41 changes: 41 additions & 0 deletions spacy_llm/models/rest/openai/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,47 @@

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.OpenAI.v1")
def openai_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
strict: bool = OpenAI.DEFAULT_STRICT,
max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
interval: float = OpenAI.DEFAULT_INTERVAL,
max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
) -> OpenAI:
"""Returns OpenAI model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (str): Model name to use. Can be any model name supported by the OpenAI API.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (OpenAI): OpenAI model instance.
"""
return OpenAI(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


"""
Parameter explanations:
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
Expand Down
46 changes: 44 additions & 2 deletions spacy_llm/models/rest/palm/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,48 @@
from .model import Endpoints, PaLM


@registry.llm_models("spacy.Google.v1")
def google_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
strict: bool = PaLM.DEFAULT_STRICT,
max_tries: int = PaLM.DEFAULT_MAX_TRIES,
interval: float = PaLM.DEFAULT_INTERVAL,
max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
endpoint: Optional[str] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google model instance using REST to prompt API.
name (str): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
RETURNS (PaLM): PaLM model instance.
"""
default_endpoint = (
Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
)
return PaLM(
name=name,
endpoint=endpoint or default_endpoint,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=None,
)


@registry.llm_models("spacy.PaLM.v2")
def palm_bison_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
Expand All @@ -18,7 +60,7 @@ def palm_bison_v2(
context_length: Optional[int] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
Expand Down Expand Up @@ -57,7 +99,7 @@ def palm_bison(
endpoint: Optional[str] = None,
) -> PaLM:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
Expand Down
7 changes: 4 additions & 3 deletions spacy_llm/pipeline/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
logger.addHandler(logging.NullHandler())

DEFAULT_MODEL_CONFIG = {
"@llm_models": "spacy.GPT-3-5.v2",
"@llm_models": "spacy.GPT-3-5.v3",
"strict": True,
}
DEFAULT_CACHE_CONFIG = {
Expand Down Expand Up @@ -238,6 +238,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
else self._task.generate_prompts(noncached_doc_batch),
n_iters + 1,
)

responses_iters = tee(
self._model(
# Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped
Expand All @@ -251,7 +252,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
)

for prompt_data, response, doc in zip(
prompts_iters[1], responses_iters[0], noncached_doc_batch
prompts_iters[1], list(responses_iters[0]), noncached_doc_batch
):
logger.debug(
"Generated prompt for doc: %s\n%s",
Expand All @@ -266,7 +267,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
elem[1] if support_sharding else noncached_doc_batch[i]
for i, elem in enumerate(prompts_iters[2])
),
responses_iters[1],
list(responses_iters[1]),
)
)

Expand Down
2 changes: 1 addition & 1 deletion spacy_llm/tests/models/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_cohere_api_response_when_error():
def test_cohere_error_unsupported_model():
"""Ensure graceful handling of error when model is not supported"""
incorrect_model = "x-gpt-3.5-turbo"
with pytest.raises(ValueError, match="model not found"):
with pytest.raises(ValueError, match="Request to Cohere API failed"):
Cohere(
name=incorrect_model,
config={},
Expand Down
Loading

0 comments on commit ff07682

Please sign in to comment.