support llms in var-gen
crazycth authored and yje-arch committed Sep 13, 2023
1 parent dbf0f35 commit 7021db0
Showing 4 changed files with 25 additions and 13 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -73,6 +73,10 @@ strategies.
| Anthropic |||||
| MidJourney | || ||

To support different models in a custom function (e.g. Model Comparison), [follow our example](https://github.com/YiVal/YiVal/blob/litellm_complete/demo/configs/model_compare.yml).

To support different models in evaluators and generators, [check our config](https://github.com/YiVal/YiVal/blob/litellm_complete/demo/configs/headline_generation.yml).

## Installation

```sh
9 changes: 6 additions & 3 deletions demo/configs/headline_generation.yml
@@ -5,6 +5,7 @@ dataset:
openai_prompt_data_generator:
chunk_size: 100000
diversify: true
# model_name specifies the LLM model to use, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
model_name: gpt-4
prompt:
"Please provide a concrete and realistic test case as a dictionary for function invocation using the ** operator.
@@ -18,7 +19,7 @@ dataset:
name: headline_generation_for_business
parameters:
tech_startup_business: str
number_of_examples: 1
number_of_examples: 2
output_path: null
source_type: machine_generated

@@ -31,8 +32,9 @@ variations:
variation_id: null
generator_name: openai_prompt_based_variation_generator
generator_config:
openai_model_name: gpt-4
number_of_variations: 1
# model_name specifies the LLM model to use, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
model_name: gpt-4
number_of_variations: 2
diversify: true
variables: null
prompt:
@@ -65,6 +67,7 @@ evaluators:
D It meets the criterion very well.
E It meets the criterion exceptionally well, with little to no room for improvement.
choices: ["A", "B", "C", "D", "E"]
# model_name specifies the LLM model to use, e.g. a16z-infra/llama-2-13b-chat:9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3
model_name: gpt-4
description: "evaluate the quality of the landing page headline"
scale_description: "0-4"
2 changes: 1 addition & 1 deletion src/yival/schemas/varation_generator_configs.py
@@ -20,7 +20,7 @@ class OpenAIPromptBasedVariationGeneratorConfig(BaseVariationGeneratorConfig):
"""
Generate variations using ChatGPT. Currently only OpenAI models are supported.
"""
openai_model_name: str = "gpt-4"
model_name: str = "gpt-4"
prompt: Union[str, List[Dict[str, str]]] = ""
diversify: bool = False
variables: Optional[List[str]] = None
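With the field renamed from `openai_model_name` to `model_name`, the same config class can now point at non-OpenAI models. Below is a minimal sketch of constructing it directly in Python, assuming the class behaves like a plain dataclass whose fields not shown in this hunk (such as `number_of_variations` and `max_tokens`) keep their defaults; the prompt text is only illustrative, and the model id is the Replicate-hosted Llama 2 example quoted in the YAML comments above.

```python
from yival.schemas.varation_generator_configs import (
    OpenAIPromptBasedVariationGeneratorConfig,
)

# model_name is no longer limited to OpenAI models; any identifier the shared
# llm_completion helper understands should be accepted here.
config = OpenAIPromptBasedVariationGeneratorConfig(
    model_name=(
        "a16z-infra/llama-2-13b-chat:"
        "9dff94b1bed5af738655d4a7cbcdcde2bd503aa85c94334fe1f42af7f3dd5ee3"
    ),
    prompt="Generate prompt variations for a startup headline generator.",  # illustrative
    diversify=True,
    variables=None,
)
```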
@@ -3,11 +3,12 @@
import pickle
from typing import Any, Dict, Iterator, List

import openai
from tqdm import tqdm

from ..common import utils
from ..common.model_utils import llm_completion
from ..schemas.experiment_config import WrapperVariation
from ..schemas.model_configs import Request
from ..schemas.varation_generator_configs import (
OpenAIPromptBasedVariationGeneratorConfig,
)
@@ -109,7 +110,7 @@ def generate_variations(self) -> Iterator[List[WrapperVariation]]:
responses = asyncio.run(
utils.parallel_completions(
message_batches,
self.config.openai_model_name,
self.config.model_name,
self.config.max_tokens,
pbar=pbar
)
@@ -131,13 +132,17 @@ def generate_variations(self) -> Iterator[List[WrapperVariation]]:
desc="Generating Variations",
unit="variation"
) as pbar:
output = openai.ChatCompletion.create(
model=self.config.openai_model_name,
messages=messages,
temperature=1,
presence_penalty=1,
max_tokens=self.config.max_tokens,
)
output = llm_completion(
Request(
model_name=self.config.model_name,
prompt=messages,
params={
"temperature": 1,
"presence_penalty": 1,
"max_tokens": self.config.max_tokens,
}
)
).output
if self.config.variables and not validate_output(
output.choices[0].message.content,
self.config.variables
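The net effect of this last hunk is that the generator no longer calls `openai.ChatCompletion.create` directly: it wraps the chat messages in a `Request` and routes them through the shared `llm_completion` helper, which is what lets `model_name` reference non-OpenAI models. A self-contained sketch of that call pattern follows, assuming `Request` accepts the `model_name`/`prompt`/`params` keywords used in the diff and that `llm_completion(...).output` follows the OpenAI-style response shape the code unwraps; the messages and the `max_tokens` value are placeholders.

```python
from yival.common.model_utils import llm_completion
from yival.schemas.model_configs import Request

# Placeholder chat messages; the real generator builds these from its prompt config.
messages = [
    {"role": "system", "content": "You generate prompt variations."},
    {"role": "user", "content": "Propose one alternative headline prompt."},
]

# Mirrors the parameters passed in the diff above; max_tokens normally comes
# from self.config.max_tokens.
response = llm_completion(
    Request(
        model_name="gpt-4",
        prompt=messages,
        params={
            "temperature": 1,
            "presence_penalty": 1,
            "max_tokens": 500,
        },
    )
).output

# The response text is read the same way as with the previous OpenAI call.
print(response.choices[0].message.content)
```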
