Skip to content

Commit

Permalink
feat(model): pass kwargs to generate from model call (hyperonym#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
creatorrr authored and peakji committed Jun 13, 2023
1 parent 5609282 commit 1677491
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions basaran/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __call__(
n=1,
logprobs=0,
echo=False,
**kwargs,
):
"""Create a completion stream for the provided prompt."""
if isinstance(prompt, str):
Expand Down Expand Up @@ -69,6 +70,17 @@ def __call__(
if echo:
yield map_choice(text, i, text_offset=offset, **samples)

generate_kwargs = {
**dict(
logprobs=logprobs,
min_new_tokens=min_tokens,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
),
**kwargs
}

# Generate completion tokens.
for (
tokens,
Expand All @@ -78,11 +90,7 @@ def __call__(
status,
) in self.generate(
input_ids[None, :].repeat(n, 1),
logprobs=logprobs,
min_new_tokens=min_tokens,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
**generate_kwargs,
):
for i in range(n):
# Check and update the finish status of the sequence.
Expand Down

0 comments on commit 1677491

Please sign in to comment.