diff --git a/intermodel/callgpt.py b/intermodel/callgpt.py
index 2b19cf6..2d91cb2 100755
--- a/intermodel/callgpt.py
+++ b/intermodel/callgpt.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+import dataclasses
 import re
 import traceback
 import asyncio
@@ -38,7 +39,7 @@
     stop=tenacity.stop_after_attempt(6),
 )
 async def complete(
-    model,
+    raw_model,
     prompt=None,
     temperature=None,
     top_p=None,
@@ -56,7 +57,7 @@ async def complete(
     vendor_config=None,
     **kwargs,
 ):
-    model = MODEL_ALIASES.get(model, model)
+    model = parse_model_string(MODEL_ALIASES.get(raw_model, raw_model)).model
     # todo: multiple completions, top k, logit bias for all vendors
     # todo: detect model not found on all vendors and throw the same exception
     if vendor is None:
@@ -314,7 +315,7 @@ async def complete(
         return {"prompt": {"text": prompt}, "completions": {}}
     elif vendor == "fake-local":
         return intermodel.callgpt_faker.fake_local(
-            model=model,
+            model=parse_model_string(raw_model).tokenize_as,
             vendor=vendor,
             prompt=prompt,
             max_tokens=max_tokens,
@@ -404,7 +405,7 @@ def complete_sync(*args, **kwargs):
 def tokenize(model: str, string: str) -> List[int]:
     import tiktoken
 
-    model = MODEL_ALIASES.get(model, model)
+    model = parse_model_string(MODEL_ALIASES.get(model, model)).tokenize_as
     try:
         vendor = pick_vendor(model)
     except NotImplementedError:
@@ -443,6 +444,27 @@ def tokenize(model: str, string: str) -> List[int]:
         return tokenizer.encode(string).ids
 
 
+@dataclasses.dataclass
+class Model:
+    model: str
+    tokenize_as: str
+    max_token_len: int
+
+
+def parse_model_string(model: str) -> Model:
+    """Parse 'model[^tokenize_as][@max_token_len]' into a Model."""
+    match = re.match(r"^(.+?)(?:\^(.+?))?(?:@(\d+))?$", model)
+    return Model(
+        model=match.group(1),
+        tokenize_as=match.group(2) or match.group(1),
+        max_token_len=(
+            int(match.group(3))
+            if match.group(3) is not None
+            else max_token_length_inner(match.group(2) or match.group(1))
+        ),
+    )
+
+
 def count_tokens(model: str, string: str) -> int:
     return len(tokenize(model, string))
 
@@ -516,6 +538,10 @@ def pick_vendor(model, custom_config=None):
 
 
 def max_token_length(model):
+    return parse_model_string(model).max_token_len
+
+
+def max_token_length_inner(model):
     """
     The maximum number of tokens in the prompt and completion
     """
diff --git a/intermodel/test_model_string.py b/intermodel/test_model_string.py
new file mode 100644
index 0000000..0cde034
--- /dev/null
+++ b/intermodel/test_model_string.py
@@ -0,0 +1,25 @@
+from intermodel.callgpt import parse_model_string, Model, max_token_length_inner
+
+
+def test_plain():
+    assert parse_model_string("davinci-002") == Model(
+        "davinci-002", "davinci-002", max_token_length_inner("davinci-002")
+    )
+
+
+def test_both():
+    assert parse_model_string("davinci-002^davinci@32000") == Model(
+        "davinci-002", "davinci", 32000
+    )
+
+
+def test_tokenizer():
+    assert parse_model_string("davinci-002^davinci") == Model(
+        "davinci-002", "davinci", max_token_length_inner("davinci")
+    )
+
+
+def test_length():
+    assert parse_model_string("davinci-002@12") == Model(
+        "davinci-002", "davinci-002", 12
+    )
diff --git a/pyproject.toml b/pyproject.toml
index ff2ef82..70a2fc5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "intermodel"
-version = "0.0.49"
+version = "0.0.50"
 dependencies = [
     "openai==0.28",
     "httpx",
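
A minimal usage sketch of the new model-string grammar (not part of the
diff; assumes the intermodel package above is importable). Both suffixes
are optional: "^" overrides the model used for tokenization and "@"
overrides the maximum token length; omitted parts fall back to the base
model name and to max_token_length_inner(), respectively.

    from intermodel.callgpt import parse_model_string

    # Grammar: model[^tokenize_as][@max_token_len]
    m = parse_model_string("davinci-002^davinci@32000")
    assert m.model == "davinci-002"    # name sent to the vendor API
    assert m.tokenize_as == "davinci"  # name used for token counting
    assert m.max_token_len == 32000    # explicit context-length override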