Commit
add model^tokenize_as@ctx_len
ampdot-io committed Sep 28, 2024
1 parent 8a5af73 commit 73dbe9a
Showing 3 changed files with 56 additions and 5 deletions.
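This commit introduces an extended model-string syntax, model^tokenize_as@ctx_len: a base model name, optionally followed by ^ plus the model whose tokenizer should be used for token counting, optionally followed by @ plus an override for the maximum token length. A minimal usage sketch, grounded in the parser and tests below (model names are illustrative):

    from intermodel.callgpt import parse_model_string

    # "davinci-002" generates, "davinci" is used for tokenization,
    # and prompt + completion are capped at 32000 tokens.
    spec = parse_model_string("davinci-002^davinci@32000")
    assert spec.model == "davinci-002"
    assert spec.tokenize_as == "davinci"
    assert spec.max_token_len == 32000

Both suffixes are optional: with no ^ the base model tokenizes as itself, and with no @ the length falls back to the model's known maximum.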
34 changes: 30 additions & 4 deletions intermodel/callgpt.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+import dataclasses
 import re
 import traceback
 import asyncio
@@ -38,7 +39,7 @@
     stop=tenacity.stop_after_attempt(6),
 )
 async def complete(
-    model,
+    raw_model,
     prompt=None,
     temperature=None,
     top_p=None,
@@ -56,7 +57,7 @@ async def complete(
     vendor_config=None,
     **kwargs,
 ):
-    model = MODEL_ALIASES.get(model, model)
+    model = parse_model_string(MODEL_ALIASES.get(raw_model, raw_model)).model
     # todo: multiple completions, top k, logit bias for all vendors
     # todo: detect model not found on all vendors and throw the same exception
     if vendor is None:
@@ -314,7 +315,7 @@ async def complete(
         return {"prompt": {"text": prompt}, "completions": {}}
     elif vendor == "fake-local":
         return intermodel.callgpt_faker.fake_local(
-            model=model,
+            model=parse_model_string(raw_model).tokenize_as,
             vendor=vendor,
             prompt=prompt,
             max_tokens=max_tokens,
@@ -404,7 +405,7 @@ def complete_sync(*args, **kwargs):
 def tokenize(model: str, string: str) -> List[int]:
     import tiktoken
 
-    model = MODEL_ALIASES.get(model, model)
+    model = parse_model_string(MODEL_ALIASES.get(model, model)).tokenize_as
     try:
         vendor = pick_vendor(model)
     except NotImplementedError:
@@ -443,6 +444,27 @@ def tokenize(model: str, string: str) -> List[int]:
         return tokenizer.encode(string).ids
 
 
+@dataclasses.dataclass
+class Model:
+    model: str
+    tokenize_as: str
+    max_token_len: int
+
+
+def parse_model_string(model: str):
+    # returns: model, tokenization model, max token length
+    match = re.match(r"^(.+?)(?:\^(.+?))?(?:@(\d+))?$", model)
+    return Model(
+        model=match.group(1),
+        tokenize_as=match.group(2) or match.group(1),
+        max_token_len=(
+            int(match.group(3))
+            if match.group(3) is not None
+            else max_token_length_inner(model)
+        ),
+    )
+
+
 def count_tokens(model: str, string: str) -> int:
     return len(tokenize(model, string))
 
@@ -516,6 +538,10 @@ def pick_vendor(model, custom_config=None):
 
 
 def max_token_length(model):
+    return parse_model_string(model).max_token_len
+
+
+def max_token_length_inner(model):
     """
     The maximum number of tokens in the prompt and completion
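With this wrapper in place, any existing caller of max_token_length transparently honors an @ override, since parse_model_string short-circuits before the lookup in max_token_length_inner runs. For example, following the diff above:

    from intermodel.callgpt import max_token_length

    assert max_token_length("davinci-002@12") == 12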
25 changes: 25 additions & 0 deletions intermodel/test_model_string.py
@@ -0,0 +1,25 @@
+from intermodel.callgpt import parse_model_string, Model, max_token_length_inner
+
+
+def test_plain():
+    assert parse_model_string("davinci-002") == Model(
+        "davinci-002", "davinci-002", max_token_length_inner("davinci-002")
+    )
+
+
+def test_both():
+    assert parse_model_string("davinci-002^davinci@32000") == Model(
+        "davinci-002", "davinci", 32000
+    )
+
+
+def test_tokenizer():
+    assert parse_model_string("davinci-002^davinci") == Model(
+        "davinci-002", "davinci", max_token_length_inner("davinci")
+    )
+
+
+def test_length():
+    assert parse_model_string("davinci-002@12") == Model(
+        "davinci-002", "davinci-002", 12
+    )
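Assuming pytest as the runner (the file uses bare asserts in pytest-style test functions), these can be run with:

    pytest intermodel/test_model_string.py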
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "intermodel"
-version = "0.0.49"
+version = "0.0.50"
 dependencies = [
     "openai==0.28",
     "httpx",